/*
 * Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.ir.generator.print.symbol

import org.jetbrains.kotlin.generators.tree.*
import org.jetbrains.kotlin.generators.tree.AbstractField.SymbolFieldRole
import org.jetbrains.kotlin.generators.tree.printer.*
import org.jetbrains.kotlin.ir.generator.Model
import org.jetbrains.kotlin.ir.generator.TREE_GENERATOR_README
import org.jetbrains.kotlin.ir.generator.declaredSymbolRemapperType
import org.jetbrains.kotlin.ir.generator.emptySymbolRemapperType
import org.jetbrains.kotlin.ir.generator.model.Element
import org.jetbrains.kotlin.ir.generator.model.symbol.Symbol
import org.jetbrains.kotlin.ir.generator.model.symbol.findFieldsWithSymbols
import org.jetbrains.kotlin.ir.generator.model.symbol.symbolRemapperMethodName
import org.jetbrains.kotlin.ir.generator.referencedSymbolRemapperType
import org.jetbrains.kotlin.utils.addToStdlib.ifNotEmpty
import java.io.File

/**
 * Base class for printers that generate a symbol remapper type (an interface or a class) for the IR tree.
 *
 * For every role in [roles], each symbol type stored in a field with that role (as found by [findFieldsWithSymbols])
 * and each of its descendant symbol types gets one remapping method, unless [shouldPrintMethodForSymbol] rejects it.
 *
 * @property elements the tree elements whose fields are scanned for symbols.
 * @property roles the symbol field roles to generate remapping methods for; may be empty.
 */
internal abstract class AbstractSymbolRemapperPrinter(
    private val printer: ImportCollectingPrinter,
    val elements: List<Element>,
    val roles: List<SymbolFieldRole>,
) {
    /** The generated type. Its kind must agree with [implementationKind]'s type kind. */
    abstract val symbolRemapperType: ClassRef<*>

    /** The kind of declaration (interface, open class, ...) printed for [symbolRemapperType]. */
    abstract val implementationKind: ImplementationKind

    /**
     * Supertypes of the generated type. When non-empty, every printed method is marked `override`
     * and no per-method KDoc is printed.
     */
    open val symbolRemapperSuperTypes: List<ClassRef<*>>
        get() = emptyList()

    /** Optional text placed before the "Auto-generated by" line in the generated type's KDoc. */
    open val kDoc: String?
        get() = null

    /** Returns `false` to skip the remapping method for [symbolClass] in [role]; every method is printed by default. */
    protected open fun shouldPrintMethodForSymbol(symbolClass: Symbol, role: SymbolFieldRole): Boolean = true

    /** Prints the declaration header of the remapping method for [symbolClass], followed by its implementation. */
    private fun ImportCollectingPrinter.printMethod(symbolClass: Symbol, returnType: Symbol, role: SymbolFieldRole) {
        val symbolParameter = FunctionParameter("symbol", symbolClass)
        printFunctionDeclaration(
            symbolRemapperMethodName(symbolClass, role),
            parameters = listOf(symbolParameter),
            returnType = returnType,
            override = symbolRemapperSuperTypes.isNotEmpty(),
        )
        printMethodImplementation(symbolParameter, symbolClass, role)
    }

    /**
     * Prints whatever follows the method header printed for [symbolClass].
     *
     * By default, a symbol type with sub-elements gets an expression body that dispatches on the runtime type of
     * [symbolParameter] to the remapping method of each sub-element. A symbol type without sub-elements only gets
     * a line break, i.e. the method is printed without a body.
     */
    protected open fun ImportCollectingPrinter.printMethodImplementation(
        symbolParameter: FunctionParameter,
        symbolClass: Symbol,
        role: SymbolFieldRole
    ) {
        if (symbolClass.subElements.isNotEmpty()) {
            print(" = when (", symbolParameter.name, ")")
            printBlock {
                for (subSymbol in symbolClass.subElements) {
                    println("is ", subSymbol.render(), " -> ", symbolRemapperMethodName(subSymbol, role), "(", symbolParameter.name, ")")
                }
            }
        } else {
            println()
        }
    }

    /** Hook for printing extra members at the end of the generated type's body. Prints nothing by default. */
    protected open fun ImportCollectingPrinter.printAdditionalDeclarations() {}

    /**
     * Prints the whole generated type. This covers its KDoc, its header with supertypes, and one remapping method
     * per applicable symbol type and role. It ends with whatever [printAdditionalDeclarations] prints.
     *
     * @throws IllegalStateException if the kind of [symbolRemapperType] does not match [implementationKind], or if a
     * return type cannot be computed because no field stores a symbol type (or any of its ancestors).
     */
    fun printSymbolRemapper() {
        printer.run {
            printKDoc(
                buildString {
                    kDoc?.let {
                        append(it)
                        appendLine()
                        appendLine()
                    }
                    append("Auto-generated by [${this@AbstractSymbolRemapperPrinter::class.qualifiedName}]")
                },
            )
            // `assert` is a no-op unless the JVM runs with `-ea`; `check` always enforces the invariant.
            check(symbolRemapperType.kind == implementationKind.typeKind) {
                "Type kind mismatch: ${symbolRemapperType.simpleName} has kind ${symbolRemapperType.kind}, " +
                        "but '${implementationKind.title}' requires ${implementationKind.typeKind}"
            }
            print(implementationKind.title, " ", symbolRemapperType.simpleName)
            symbolRemapperSuperTypes.ifNotEmpty {
                // Class supertypes need a constructor call in the supertype list.
                print(joinToString(prefix = " : ") { it.render() + if (it.kind == TypeKind.Class) "()" else "" })
            }
            printBlock {
                for (role in roles) {
                    val fieldsAndSymbols = findFieldsWithSymbols(elements, role)
                    val symbols = fieldsAndSymbols.keys.flatMap { it.elementDescendantsAndSelfDepthFirst() }.distinct()
                    for (symbolType in symbols) {
                        if (!shouldPrintMethodForSymbol(symbolType, role)) continue
                        val fields = symbolType.elementAncestorsAndSelfDepthFirst().flatMap { fieldsAndSymbols[it].orEmpty() }.toList()

                        // If this symbol type is used in some field as is, use it as the return type of the corresponding method.
                        // Otherwise, if this symbol type is only used as a subtype, set the return type to its most specific supertype
                        // used in any of the fields.
                        val mostSpecificReturnType = checkNotNull(
                            fields.fold(null) { acc: Symbol?, field ->
                                if (acc?.isSubclassOf(field.symbolType) == true) acc else field.symbolType
                            },
                        ) {
                            "Cannot determine the return type of ${symbolRemapperMethodName(symbolType, role)}: " +
                                    "no field stores this symbol type or any of its ancestors"
                        }

                        println()
                        if (symbolRemapperSuperTypes.isEmpty()) {
                            // Named distinctly so it does not shadow the class-level `kDoc` property.
                            val methodKDoc = buildString {
                                append("Remaps symbols stored, e.g., in the following properties (not necessarily limited to those properties):")
                                for ((_, fieldName, _, element) in fields) {
                                    appendLine()
                                    append("- [${element.render()}.$fieldName]")
                                }
                            }
                            printKDoc(methodKDoc)
                        }
                        printMethod(symbolType, mostSpecificReturnType, role)
                    }
                }
                printAdditionalDeclarations()
            }
        }
    }
}

/**
 * Prints the remapper interface whose methods cover symbols stored in fields with the
 * [SymbolFieldRole.DECLARED] role.
 */
internal class DeclaredSymbolRemapperInterfacePrinter(
    printer: ImportCollectingPrinter,
    elements: List<Element>,
    override val symbolRemapperType: ClassRef<*>,
) : AbstractSymbolRemapperPrinter(
    printer = printer,
    elements = elements,
    roles = listOf(SymbolFieldRole.DECLARED),
) {
    override val kDoc: String
        get() = "Used to replace declarations' own symbols with new ones."

    override val implementationKind: ImplementationKind
        get() = ImplementationKind.Interface
}

/**
 * Prints the remapper interface whose methods cover symbols stored in fields with the
 * [SymbolFieldRole.REFERENCED] role.
 */
internal class ReferencedSymbolRemapperInterfacePrinter(
    printer: ImportCollectingPrinter,
    elements: List<Element>,
    override val symbolRemapperType: ClassRef<*>,
) : AbstractSymbolRemapperPrinter(
    printer = printer,
    elements = elements,
    roles = listOf(SymbolFieldRole.REFERENCED),
) {
    override val kDoc: String
        get() = "Used to replace symbols that represent references to declarations other than the symbol's owner."

    override val implementationKind: ImplementationKind
        get() = ImplementationKind.Interface
}

/**
 * Prints the combined remapper interface that extends both the declared- and the referenced-symbol remapper
 * interfaces.
 *
 * It declares no remapping methods of its own, because its role list is empty. Instead, its body contains the
 * generated empty implementation and a companion object exposing an `EMPTY` instance of it.
 */
internal class SymbolRemapperInterfacePrinter(
    printer: ImportCollectingPrinter,
    elements: List<Element>,
    override val symbolRemapperType: ClassRef<*>,
) : AbstractSymbolRemapperPrinter(printer = printer, elements = elements, roles = emptyList()) {
    override val symbolRemapperSuperTypes: List<ClassRef<*>>
        get() = listOf(declaredSymbolRemapperType, referencedSymbolRemapperType)

    override val implementationKind: ImplementationKind
        get() = ImplementationKind.Interface

    override fun ImportCollectingPrinter.printAdditionalDeclarations() {
        // Nested default implementation that returns every symbol unchanged.
        println()
        val emptyRemapperPrinter = EmptySymbolRemapperPrinter(this, elements)
        emptyRemapperPrinter.printSymbolRemapper()

        // Shared instance of that default implementation.
        println()
        print("companion object")
        printBlock {
            val interfaceName = symbolRemapperType.render()
            val emptyClassName = emptySymbolRemapperType.simpleName
            println("val EMPTY: $interfaceName = $emptyClassName()")
        }
    }
}

/**
 * Prints the open class behind `emptySymbolRemapperType`. It implements the combined symbol remapper interface, and
 * every method it prints returns its `symbol` argument unchanged.
 *
 * Methods are printed for both the declared and the referenced role, but only for symbol types without
 * sub-elements.
 */
private class EmptySymbolRemapperPrinter(
    printer: ImportCollectingPrinter,
    elements: List<Element>,
) : AbstractSymbolRemapperPrinter(
    printer = printer,
    elements = elements,
    roles = listOf(SymbolFieldRole.DECLARED, SymbolFieldRole.REFERENCED),
) {
    override val implementationKind: ImplementationKind
        get() = ImplementationKind.OpenClass

    override val symbolRemapperType = emptySymbolRemapperType

    // Fully qualified: the inherited `symbolRemapperType` property would otherwise shadow the top-level one.
    override val symbolRemapperSuperTypes: List<ClassRef<*>>
        get() = listOf(org.jetbrains.kotlin.ir.generator.symbolRemapperType)

    override val kDoc: String
        get() {
            val interfaceName = org.jetbrains.kotlin.ir.generator.symbolRemapperType.simpleName
            return "The default implementation of [$interfaceName]\nthat just keeps the old symbols everywhere."
        }

    override fun shouldPrintMethodForSymbol(symbolClass: Symbol, role: SymbolFieldRole): Boolean =
        symbolClass.subElements.isEmpty()

    override fun ImportCollectingPrinter.printMethodImplementation(
        symbolParameter: FunctionParameter,
        symbolClass: Symbol,
        role: SymbolFieldRole,
    ) {
        // Identity body: `= symbol`.
        println(" = ", symbolParameter.name)
    }
}

/**
 * Prints the source of the symbol remapper [type] via [printGeneratedType]. The output location is derived from
 * [generationPath] and from [type]'s package and simple name.
 *
 * @param makePrinter builds the concrete printer for [type] from the collecting printer, the model's elements and
 * [type] itself.
 */
internal fun printSymbolRemapper(
    generationPath: File,
    model: Model,
    type: ClassRef<*>,
    makePrinter: (ImportCollectingPrinter, List<Element>, ClassRef<*>) -> AbstractSymbolRemapperPrinter,
) = printGeneratedType(generationPath, TREE_GENERATOR_README, type.packageName, type.simpleName) {
    val remapperPrinter = makePrinter(this, model.elements, type)
    remapperPrinter.printSymbolRemapper()
}
